{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:13:49.704011Z",
     "start_time": "2019-12-27T11:13:43.883398Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from utils import tag2idx, idx2tag\n",
    "from crf import Bert_BiLSTM_CRF\n",
    "model = Bert_BiLSTM_CRF(tag2idx)\n",
    "\n",
    "model.load_state_dict(torch.load('./checkpoints/finetuning/100.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:13:50.572906Z",
     "start_time": "2019-12-27T11:13:50.562794Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Bert_BiLSTM_CRF(\n",
       "  (lstm): LSTM(768, 384, num_layers=2, batch_first=True, bidirectional=True)\n",
       "  (fc): Linear(in_features=768, out_features=16, bias=True)\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(21128, 768)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): BertLayerNorm()\n",
       "      (dropout): Dropout(p=0.1)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (1): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (2): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (3): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (4): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (5): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (6): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (7): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (8): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (9): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (10): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (11): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:13:53.727403Z",
     "start_time": "2019-12-27T11:13:53.701139Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from pytorch_pretrained_bert import BertTokenizer\n",
    "bert_model = '/root/workspace/qa_project/chinese_L-12_H-768_A-12'\n",
    "tokenizer = BertTokenizer.from_pretrained(bert_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:54:45.454004Z",
     "start_time": "2019-12-27T11:54:45.450239Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['延', '更', '丹', '可', '以', '与', '大', '豆', '异', '黄', '胴', '一', '起', '吃', '吗']"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = '延更丹可以与大豆异黄胴一起吃吗'\n",
    "tokens = tokenizer.tokenize(text)\n",
    "tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:54:46.655998Z",
     "start_time": "2019-12-27T11:54:46.652696Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]',\n",
       " '延',\n",
       " '更',\n",
       " '丹',\n",
       " '可',\n",
       " '以',\n",
       " '与',\n",
       " '大',\n",
       " '豆',\n",
       " '异',\n",
       " '黄',\n",
       " '胴',\n",
       " '一',\n",
       " '起',\n",
       " '吃',\n",
       " '吗',\n",
       " '[SEP]']"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens = ['[CLS]'] + tokens + ['[SEP]']\n",
    "tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:54:47.684187Z",
     "start_time": "2019-12-27T11:54:47.680801Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[101,\n",
       " 2454,\n",
       " 3291,\n",
       " 710,\n",
       " 1377,\n",
       " 809,\n",
       " 680,\n",
       " 1920,\n",
       " 6486,\n",
       " 2460,\n",
       " 7942,\n",
       " 5539,\n",
       " 671,\n",
       " 6629,\n",
       " 1391,\n",
       " 1408,\n",
       " 102]"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xx = tokenizer.convert_tokens_to_ids(tokens)\n",
    "xx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:54:51.083494Z",
     "start_time": "2019-12-27T11:54:51.079162Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 17])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()\n",
    "xx = torch.tensor(xx).unsqueeze(0)\n",
    "xx.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:54:52.603513Z",
     "start_time": "2019-12-27T11:54:52.577998Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1, 12, 13, 13,  3,  3,  3, 12, 13, 13, 13,  3,  3,  3,  3,  3,  2]])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = model.to(torch.device('cuda'))\n",
    "xx = xx.to(torch.device('cuda'))\n",
    "_, y_hat = model(torch.tensor(xx))\n",
    "y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:54:54.353910Z",
     "start_time": "2019-12-27T11:54:54.350222Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]',\n",
       " 'B-DRG',\n",
       " 'I-DRG',\n",
       " 'I-DRG',\n",
       " 'O',\n",
       " 'O',\n",
       " 'O',\n",
       " 'B-DRG',\n",
       " 'I-DRG',\n",
       " 'I-DRG',\n",
       " 'I-DRG',\n",
       " 'O',\n",
       " 'O',\n",
       " 'O',\n",
       " 'O',\n",
       " 'O',\n",
       " '[SEP]']"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_tags = []\n",
    "for tag in y_hat.squeeze():\n",
    "    pred_tags.append(idx2tag[tag.item()])\n",
    "pred_tags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:17:23.812301Z",
     "start_time": "2019-12-27T11:17:23.791831Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] [CLS]\n",
      "你 O\n",
      "的 O\n",
      "症 O\n",
      "状 O\n",
      "要 O\n",
      "考 O\n",
      "虑 O\n",
      "为 O\n",
      "消 B-DSE\n",
      "化 I-DSE\n",
      "不 I-DSE\n",
      "良 I-DSE\n",
      "， O\n",
      "有 O\n",
      "可 O\n",
      "能 O\n",
      "是 O\n",
      "受 O\n",
      "凉 O\n",
      "或 O\n",
      "饮 O\n",
      "食 O\n",
      "不 O\n",
      "当 O\n",
      "引 O\n",
      "起 O\n",
      "的 O\n",
      "， O\n",
      "会 O\n",
      "引 O\n",
      "起 O\n",
      "大 O\n",
      "便 O\n",
      "不 O\n",
      "成 O\n",
      "型 O\n",
      "的 O\n",
      "。 O\n",
      "指 O\n",
      "导 O\n",
      "意 O\n",
      "见 O\n",
      "： O\n",
      "建 O\n",
      "议 O\n",
      "可 O\n",
      "以 O\n",
      "服 O\n",
      "用 O\n",
      "乳 B-DRG\n",
      "酸 I-DRG\n",
      "菌 I-DRG\n",
      "素 I-DRG\n",
      "片 I-DRG\n",
      "、 O\n",
      "维 B-DRG\n",
      "生 I-DRG\n",
      "素 I-DRG\n",
      "b1 I-DRG\n",
      "片 O\n",
      "和 O\n",
      "蒙 B-DRG\n",
      "脱 I-DRG\n",
      "石 I-DRG\n",
      "散 I-DRG\n",
      "， O\n",
      "注 O\n",
      "意 O\n",
      "休 O\n",
      "息 O\n",
      "， O\n",
      "多 O\n",
      "喝 O\n",
      "开 O\n",
      "水 O\n",
      "， O\n",
      "饮 O\n",
      "食 O\n",
      "以 O\n",
      "清 O\n",
      "淡 O\n",
      "为 O\n",
      "主 O\n",
      "， O\n",
      "禁 O\n",
      "辛 O\n",
      "辣 O\n",
      "刺 O\n",
      "激 O\n",
      "性 O\n",
      "食 O\n",
      "物 O\n",
      "和 O\n",
      "油 O\n",
      "腻 O\n",
      "煎 O\n",
      "炸 O\n",
      "食 O\n",
      "物 O\n",
      "。 O\n",
      "注 O\n",
      "意 O\n",
      "保 O\n",
      "暖 O\n",
      "， O\n",
      "避 O\n",
      "免 O\n",
      "着 O\n",
      "凉 O\n",
      "。 O\n",
      "[SEP] [SEP]\n"
     ]
    }
   ],
   "source": [
    "for s, tag in zip(tokens, pred_tags):\n",
    "    print(s, tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:37:01.999976Z",
     "start_time": "2019-12-27T11:37:01.995909Z"
    }
   },
   "outputs": [],
   "source": [
    "def map_ner(text, result):\n",
    "    entities = []\n",
    "    entity = None\n",
    "    for idx , st in enumerate(result):\n",
    "        if st.startswith('B'):\n",
    "            entity = {}\n",
    "            # entity['class'] = class_map[st.split('-')[-1]]\n",
    "            entity['start'] = idx\n",
    "        if entity is not None and st == 'O':\n",
    "            entity['end'] = idx\n",
    "            entity['name'] = text[entity['start']:entity['end']]\n",
    "            entities.append(entity)\n",
    "            entity = None\n",
    "    return entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T11:37:02.462179Z",
     "start_time": "2019-12-27T11:37:02.458184Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'start': 9, 'end': 13, 'name': ['消', '化', '不', '良']},\n",
       " {'start': 50, 'end': 55, 'name': ['乳', '酸', '菌', '素', '片']},\n",
       " {'start': 56, 'end': 60, 'name': ['维', '生', '素', 'b1']},\n",
       " {'start': 62, 'end': 66, 'name': ['蒙', '脱', '石', '散']}]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "map_ner(tokens, pred_tags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-12-27T12:01:02.424570Z",
     "start_time": "2019-12-27T12:00:57.941519Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n",
      "['罗红霉素', '头孢']\n"
     ]
    }
   ],
   "source": [
    "from predict import main\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
